import torch
import torch.nn as nn
import torch.nn.functional as F
from attention_modules import FullContextEmbedding, CosineAttention, DotProductAttention


class ResidualBlock(nn.Module):
    """
    1x1-conv residual block used by the feature extractor (mirrors the one
    in PrototypicalNetwork).

    Expands the channel dimension by 2x, projects it back down, adds the
    identity shortcut, and finishes with a ReLU. Input and output shapes are
    both (batch, in_channels, length), so the shortcut needs no projection.
    """
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        expanded = in_channels * 2
        self.conv1 = nn.Conv1d(in_channels, expanded, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm1d(expanded)
        self.conv2 = nn.Conv1d(expanded, in_channels, kernel_size=1, stride=1, padding=0)
        self.bn2 = nn.BatchNorm1d(in_channels)

    def forward(self, x):
        # conv -> bn -> relu -> conv -> bn, then the skip connection and a
        # final ReLU on the sum.
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(branch + x)


class MatchingNetwork(nn.Module):
    """
    Matching Networks for few-shot learning.

    Embeds support and query examples with a shared conv/residual/transformer
    encoder followed by an MLP, optionally refines support embeddings with a
    Full Context Embedding (FCE), and scores each query against each class by
    summing its attention weights over that class's support examples.
    """
    def __init__(self, input_size, hidden_size, output_size, nhead=8, num_layers=2, seq_len=8,
                 use_fce=True, attention_type='cosine'):
        super(MatchingNetwork, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.seq_len = seq_len
        self.use_fce = use_fce
        self.attention_type = attention_type

        # Feature extraction layers (same as PrototypicalNetwork)
        self.conv1 = nn.Conv1d(1, input_size, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm1d(input_size)

        # Residual blocks for feature extraction
        self.residual_blocks = nn.Sequential(
            ResidualBlock(input_size),
            ResidualBlock(input_size),
            ResidualBlock(input_size)
        )

        # Transformer encoder for feature extraction.
        # NOTE(review): batch_first=True here, but embed() permutes the input
        # to (seq, batch, feat) before calling it — one of the two looks wrong;
        # confirm the intended layout against training results before changing.
        encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Embedding network (MLP over the flattened transformer output)
        self.fc1 = nn.Linear(seq_len * input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size // 4)
        self.fc4 = nn.Linear(hidden_size // 4, hidden_size // 2)
        self.fc5 = nn.Linear(hidden_size // 2, hidden_size // 2)
        self.fc6 = nn.Linear(hidden_size // 2, output_size)
        self.dropout = nn.Dropout(0.5)

        # Full Context Embedding for support set
        if self.use_fce:
            self.fce = FullContextEmbedding(output_size, hidden_size // 4)
            self.fce_output_size = hidden_size // 2  # bidirectional LSTM output
            # BUG FIX: this projection used to be created lazily inside
            # forward() via hasattr(), so it was never registered when the
            # module was constructed — optimizers built from parameters()
            # before the first forward pass silently skipped it, it was
            # missing from state_dict(), and .to(device) did not move it.
            # embed() outputs output_size features, so the in-dim is fixed.
            self.query_projection = nn.Linear(output_size, self.fce_output_size)
        else:
            self.fce_output_size = output_size

        # Attention mechanism operating in the (possibly FCE-projected) space
        attention_input_size = self.fce_output_size if self.use_fce else output_size
        if attention_type == 'cosine':
            self.attention = CosineAttention(attention_input_size)
        elif attention_type == 'dot_product':
            self.attention = DotProductAttention(attention_input_size)
        else:
            raise ValueError(f"Unknown attention type: {attention_type}")

    def embed(self, x):
        """
        Extract an embedding for each example.

        Args:
            x: (batch_size, num_features) raw input batch.

        Returns:
            (batch_size, output_size) embeddings.
        """
        x = x.unsqueeze(1)  # Add channel dimension -> (batch, 1, num_features)
        x = F.relu(self.conv1(x))
        # NOTE(review): ReLU is applied both before and after bn1; possibly
        # intentional, but conv -> bn -> relu is the conventional order.
        x = F.relu(self.bn1(x))
        x = self.residual_blocks(x)

        # Transformer processing (see layout note in __init__)
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, input_size)
        x = self.transformer_encoder(x)
        x = x.permute(1, 0, 2).contiguous()  # (batch_size, seq_len, input_size)

        # Flatten and pass through fully connected layers
        x = x.view(x.size(0), -1)  # (batch_size, seq_len * input_size)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.dropout(x)
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = self.fc6(x)  # (batch_size, output_size)

        return x

    def forward(self, support_set, support_labels, query_set, num_classes, num_support):
        """
        Forward pass for Matching Networks.

        Args:
            support_set: (num_classes * num_support, input_dim)
            support_labels: (num_classes * num_support,) integer class labels
            query_set: (num_query, input_dim)
            num_classes: number of classes in the episode
            num_support: number of support examples per class (unused here,
                kept for interface compatibility with callers)

        Returns:
            predictions: (num_query, num_classes) unnormalized class scores
                (per-class sums of attention weights)
        """
        # Extract embeddings
        support_embeddings = self.embed(support_set)  # (num_classes * num_support, output_size)
        query_embeddings = self.embed(query_set)  # (num_query, output_size)

        # Apply Full Context Embedding if enabled, and project queries into
        # the FCE output space so attention compares equal-sized vectors.
        if self.use_fce:
            support_embeddings = self.fce(support_embeddings, support_labels, num_classes)
            query_embeddings = self.query_projection(query_embeddings)

        # For each query, compute attention over the support set
        num_query = query_embeddings.size(0)
        predictions = torch.zeros(num_query, num_classes, device=query_embeddings.device)

        for i, query_embedding in enumerate(query_embeddings):
            # Attention weights of this query over every support example
            attention_weights = self.attention(query_embedding, support_embeddings, support_labels)

            # Score each class by the total attention mass on its supports
            for class_idx in range(num_classes):
                class_mask = (support_labels == class_idx)
                predictions[i, class_idx] = attention_weights[class_mask].sum()

        return predictions


def matching_loss(predictions, query_labels):
    """
    Cross-entropy loss for matching-network predictions.

    Args:
        predictions: (num_query, num_classes) attention-based class scores
        query_labels: (num_query,) integer ground-truth labels

    Returns:
        Scalar loss tensor.
    """
    # F.cross_entropy is log_softmax followed by nll_loss in a single call,
    # exactly what the matching-network objective requires.
    return F.cross_entropy(predictions, query_labels)


def compute_accuracy(predictions, query_labels):
    """
    Fraction of queries whose highest-scoring class matches the true label.

    Args:
        predictions: (num_query, num_classes) class scores
        query_labels: (num_query,) integer ground-truth labels

    Returns:
        accuracy: Python float in [0, 1].
    """
    hits = predictions.argmax(dim=1).eq(query_labels).sum().item()
    # Plain Python division keeps full double precision for the ratio.
    return hits / query_labels.size(0)


def process_batch_matching(matching_net, batch, device, num_classes, num_support):
    """
    Split an episode batch into support/query sets and run the matching net.

    The first ``num_classes * num_support`` examples in the batch form the
    support set; everything after is the query set.

    Args:
        matching_net: MatchingNetwork model (or any callable with the same
            signature)
        batch: (data, labels) tuple of tensors
        device: torch device to move the batch to
        num_classes: number of classes in the episode
        num_support: number of support examples per class

    Returns:
        predictions: (num_query, num_classes) class scores from the model
        query_labels: (num_query,) labels for the query examples
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Split into support and query sets at a single boundary index.
    # (Removed an unused `num_query` local from the original.)
    n_support = num_classes * num_support
    support_set, query_set = data[:n_support], data[n_support:]
    support_labels, query_labels = labels[:n_support], labels[n_support:]

    # Forward pass
    predictions = matching_net(support_set, support_labels, query_set, num_classes, num_support)

    return predictions, query_labels
